{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generative Adversarial Networks\n",
    "\n",
    "Generative adversarial networks (GANs) are a powerful approach for\n",
    "probabilistic modeling (I. Goodfellow et al., 2014; I. Goodfellow, 2016).\n",
    "They posit a deep generative model and they enable fast and accurate\n",
    "inferences.\n",
    "\n",
    "We demonstrate with an example in Edward. A webpage version is available at\n",
    "http://edwardlib.org/tutorials/gan."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import print_function\n",
    "from __future__ import division\n",
    "\n",
    "import edward as ed\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import numpy as np\n",
    "import os\n",
    "import tensorflow as tf\n",
    "\n",
    "from edward.models import Uniform\n",
    "from observations import mnist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def plot(samples):\n",
    "  fig = plt.figure(figsize=(4, 4))\n",
    "  gs = gridspec.GridSpec(4, 4)\n",
    "  gs.update(wspace=0.05, hspace=0.05)\n",
    "\n",
    "  for i, sample in enumerate(samples):\n",
    "    ax = plt.subplot(gs[i])\n",
    "    plt.axis('off')\n",
    "    ax.set_xticklabels([])\n",
    "    ax.set_yticklabels([])\n",
    "    ax.set_aspect('equal')\n",
    "    plt.imshow(sample.reshape(28, 28), cmap='Greys_r')\n",
    "\n",
    "  return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "ed.set_seed(42)\n",
    "\n",
    "data_dir = \"/tmp/data\"\n",
    "out_dir = \"/tmp/out\"\n",
    "if not os.path.exists(out_dir):\n",
    "  os.makedirs(out_dir)\n",
    "M = 128  # batch size during training\n",
    "d = 100  # latent dimension"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "We use training data from MNIST, which consists of 55,000 $28\\times\n",
    "28$ pixel images (LeCun, Bottou, Bengio, & Haffner, 1998). Each image is represented\n",
    "as a flattened vector of 784 elements, and each element is a pixel\n",
    "intensity between 0 and 1.\n",
    "\n",
    "![GAN Fig 0](https://raw.githubusercontent.com/blei-lab/edward/master/docs/images/gan-fig0.png)\n",
    "\n",
    "\n",
    "The goal is to build and infer a model that can generate high quality\n",
    "images of handwritten digits.\n",
    "\n",
    "During training we will feed batches of MNIST digits. We instantiate a\n",
    "TensorFlow placeholder with a fixed batch size of $M$ images.\n",
    "\n",
    "We also define a helper function to select the next batch of data\n",
    "points from the full set of examples. It keeps track of the current\n",
    "batch index and returns the next batch using the function ``next()``.\n",
    "We will generate batches from `x_train_generator` during inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generator(array, batch_size):\n",
    "  \"\"\"Generate batch with respect to array's first axis.\"\"\"\n",
    "  start = 0  # pointer to where we are in iteration\n",
    "  while True:\n",
    "    stop = start + batch_size\n",
    "    diff = stop - array.shape[0]\n",
    "    if diff <= 0:\n",
    "      batch = array[start:stop]\n",
    "      start += batch_size\n",
    "    else:\n",
    "      batch = np.concatenate((array[start:], array[:diff]))\n",
    "      start = diff\n",
    "    batch = batch.astype(np.float32) / 255.0  # normalize pixel intensities\n",
    "    batch = np.random.binomial(1, batch)  # binarize images\n",
    "    yield batch"
   ]
  },
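  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the wrap-around logic, the sketch below runs the same\n",
    "start/stop bookkeeping on a toy array of five integers. The name\n",
    "`cyclic_batches` is a hypothetical stand-in for `generator` above; it\n",
    "leaves out the normalization and binarization steps so that the batch\n",
    "contents are easy to read."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def cyclic_batches(array, batch_size):\n",
    "  \"\"\"Yield batches along the first axis, wrapping around at the end.\"\"\"\n",
    "  start = 0\n",
    "  while True:\n",
    "    stop = start + batch_size\n",
    "    diff = stop - array.shape[0]\n",
    "    if diff <= 0:\n",
    "      batch = array[start:stop]\n",
    "      start += batch_size\n",
    "    else:\n",
    "      # Not enough points left: take the tail and refill from the front.\n",
    "      batch = np.concatenate((array[start:], array[:diff]))\n",
    "      start = diff\n",
    "    yield batch\n",
    "\n",
    "\n",
    "toy = np.arange(5)\n",
    "toy_batches = cyclic_batches(toy, 2)\n",
    "first_four = [next(toy_batches).tolist() for k in range(4)]\n",
    "print(first_four)  # [[0, 1], [2, 3], [4, 0], [1, 2]]"
   ]
  },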
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">> Downloading train-images-idx3-ubyte.gz 100.1%\n",
      "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n",
      ">> Downloading train-labels-idx1-ubyte.gz 113.5%\n",
      "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n",
      ">> Downloading t10k-images-idx3-ubyte.gz 100.4%\n",
      "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n",
      ">> Downloading t10k-labels-idx1-ubyte.gz 180.4%\n",
      "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n"
     ]
    }
   ],
   "source": [
    "(x_train, _), (x_test, _) = mnist(data_dir)\n",
    "x_train_generator = generator(x_train, M)\n",
    "x_ph = tf.placeholder(tf.float32, [M, 784])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "GANs posit generative models using an implicit mechanism. Given some\n",
    "random noise, the data is assumed to be generated by a deterministic\n",
    "function of that noise.\n",
    "\n",
    "Formally, the generative process is\n",
    "\n",
    "\\begin{align*}\n",
    "\\mathbf{\\epsilon} &\\sim p(\\mathbf{\\epsilon}), \\\\\n",
    "\\mathbf{x} &= G(\\mathbf{\\epsilon}; \\theta),\n",
    "\\end{align*}\n",
    "\n",
    "where $G(\\cdot; \\theta)$ is a neural network that takes the samples\n",
    "$\\mathbf{\\epsilon}$ as input. The distribution\n",
    "$p(\\mathbf{\\epsilon})$ is interpreted as random noise injected to\n",
    "produce stochasticity in a physical system; it is typically a fixed\n",
    "uniform or normal distribution with some latent dimensionality.\n",
    "\n",
    "In Edward, we build the model as follows, using `tf.layers` to\n",
    "specify the neural network. It defines a 2-layer fully connected neural\n",
    "network and outputs a vector of length $28\\times28$ with values in\n",
    "$[0,1]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generative_network(eps):\n",
    "  net = tf.layers.dense(eps, 128, activation=tf.nn.relu)\n",
    "  net = tf.layers.dense(net, 784, activation=tf.sigmoid)\n",
    "  return net\n",
    "\n",
    "with tf.variable_scope(\"Gen\"):\n",
    "  eps = Uniform(tf.zeros([M, d]) - 1.0, tf.ones([M, d]))\n",
    "  x = generative_network(eps)"
   ]
  },
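  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the generative process concrete outside of TensorFlow, here is a\n",
    "minimal NumPy sketch of one forward pass. The weights are random and\n",
    "untrained, and the names (`sketch_generator`, `w1`, `w2`, the batch size\n",
    "of 4) are illustrative assumptions rather than part of Edward. The point is\n",
    "only that uniform noise of dimension 100 is mapped deterministically to 784\n",
    "values in $[0,1]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)  # local RNG; leaves the global seed untouched\n",
    "n_latent, n_hidden, n_pixels = 100, 128, 784\n",
    "\n",
    "# Random, untrained weights standing in for the two dense layers.\n",
    "w1 = 0.1 * rng.randn(n_latent, n_hidden)\n",
    "b1 = np.zeros(n_hidden)\n",
    "w2 = 0.1 * rng.randn(n_hidden, n_pixels)\n",
    "b2 = np.zeros(n_pixels)\n",
    "\n",
    "\n",
    "def sketch_generator(eps_np):\n",
    "  \"\"\"Map noise to pixel intensities: ReLU hidden layer, sigmoid output.\"\"\"\n",
    "  hidden = np.maximum(eps_np.dot(w1) + b1, 0.0)\n",
    "  return 1.0 / (1.0 + np.exp(-(hidden.dot(w2) + b2)))\n",
    "\n",
    "\n",
    "eps_np = rng.uniform(-1.0, 1.0, size=(4, n_latent))  # epsilon ~ Uniform(-1, 1)\n",
    "x_np = sketch_generator(eps_np)\n",
    "print(x_np.shape)  # (4, 784)"
   ]
  },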
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We aim to estimate parameters of the generative network such\n",
    "that the model best captures the data. (Note in GANs, we are\n",
    "interested only in parameter estimation and not inference about any\n",
    "latent variables.)\n",
    "\n",
    "Unfortunately, probability models described above do not admit a tractable\n",
    "likelihood. This poses a problem for most inference algorithms, as\n",
    "they usually require taking the model's density.  Thus we are\n",
    "motivated to use \"likelihood-free\" algorithms\n",
    "(Marin, Pudlo, Robert, & Ryder, 2012), a class of methods which assume one\n",
    "can only sample from the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "A key idea in likelihood-free methods is to learn by\n",
    "comparison (e.g., Rubin (1984; Gretton, Borgwardt, Rasch, Schölkopf, & Smola, 2012)): by\n",
    "analyzing the discrepancy between samples from the model and samples\n",
    "from the true data distribution, we have information on where the\n",
    "model can be improved in order to generate better samples.\n",
    "\n",
    "In GANs, a neural network $D(\\cdot;\\phi)$ makes this comparison,\n",
    "known as the discriminator.\n",
    "$D(\\cdot;\\phi)$ takes data $\\mathbf{x}$ as input (either\n",
    "generations from the model or data points from the data set), and it\n",
    "calculates the probability that $\\mathbf{x}$ came from the true data.\n",
    "\n",
    "In Edward, we use the following discriminative network. It is simply a\n",
    "feedforward network with one ReLU hidden layer. It returns the\n",
    "probability in the logit (unconstrained) scale."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def discriminative_network(x):\n",
    "  \"\"\"Outputs probability in logits.\"\"\"\n",
    "  net = tf.layers.dense(x, 128, activation=tf.nn.relu)\n",
    "  net = tf.layers.dense(net, 1, activation=None)\n",
    "  return net"
   ]
  },
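  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the network returns logits rather than probabilities, it can help to\n",
    "see the conversion explicitly. The following is a small NumPy sketch (the\n",
    "helper names `logit_to_prob` and `prob_to_logit` are made up for\n",
    "illustration): a logit of 0 corresponds to a probability of 0.5, and large\n",
    "positive or negative logits correspond to confident predictions of \"real\"\n",
    "or \"generated\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def logit_to_prob(logit):\n",
    "  \"\"\"Sigmoid: map an unconstrained logit to a probability in (0, 1).\"\"\"\n",
    "  return 1.0 / (1.0 + np.exp(-logit))\n",
    "\n",
    "\n",
    "def prob_to_logit(prob):\n",
    "  \"\"\"Inverse of the sigmoid: the log-odds of a probability.\"\"\"\n",
    "  return np.log(prob) - np.log1p(-prob)\n",
    "\n",
    "\n",
    "example_logits = np.array([-4.0, 0.0, 4.0])\n",
    "example_probs = logit_to_prob(example_logits)\n",
    "print(example_probs)  # roughly [0.018, 0.5, 0.982]"
   ]
  },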
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let $p^*(\\mathbf{x})$ represent the true data distribution.\n",
    "The optimization problem used in GANs is\n",
    "\n",
    "\\begin{equation*}\n",
    "\\min_\\theta \\max_\\phi~\n",
    "\\mathbb{E}_{p^*(\\mathbf{x})} [ \\log D(\\mathbf{x}; \\phi) ]\n",
    "+ \\mathbb{E}_{p(\\mathbf{x}; \\theta)} [ \\log (1 - D(\\mathbf{x}; \\phi)) ].\n",
    "\\end{equation*}\n",
    "\n",
    "This optimization problem is bilevel: it requires a minima solution\n",
    "with respect to generative parameters and a maxima solution with\n",
    "respect to discriminative parameters.\n",
    "In practice, the algorithm proceeds by iterating gradient updates on\n",
    "each. An\n",
    "additional heuristic also modifies the objective function for the\n",
    "generative model in order to avoid saturation of gradients\n",
    "(I. J. Goodfellow, 2014).\n",
    "\n",
    "Many sources of intuition exist behind GAN-style training. One, which\n",
    "is the original motivation, is based on idea that the two neural\n",
    "networks are playing a game. The discriminator tries to best\n",
    "distinguish samples away from the generator. The generator tries\n",
    "to produce samples that are indistinguishable by the discriminator.\n",
    "The goal of training is to reach a Nash equilibrium.\n",
    "\n",
    "Another source is the idea of casting unsupervised learning as\n",
    "supervised learning\n",
    "(M. U. Gutmann, Dutta, Kaski, & Corander, 2014; M. Gutmann & Hyvärinen, 2010).\n",
    "This allows one to leverage the power of classification—a problem that\n",
    "in recent years is (relatively speaking) very easy.\n",
    "\n",
    "A third comes from classical statistics, where the discriminator is\n",
    "interpreted as a proxy of the density ratio between the true data\n",
    "distribution and the model\n",
    " (Mohamed & Lakshminarayanan, 2016; Sugiyama, Suzuki, & Kanamori, 2012). By augmenting an\n",
    "original problem that may require the model's density with a\n",
    "discriminator (such as maximum likelihood), one can recover the\n",
    "original problem when the discriminator is optimal. Furthermore, this\n",
    "approximation is very fast, and it justifies GANs from the perspective\n",
    "of approximate inference.\n",
    "\n",
    "In Edward, the GAN algorithm (`GANInference`) simply takes the\n",
    "implicit density model on `x` as input, binded to its\n",
    "realizations `x_ph`. In addition, a parameterized function\n",
    "`discriminator` is provided to distinguish their\n",
    "samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "inference = ed.GANInference(\n",
    "    data={x: x_ph}, discriminator=discriminative_network)"
   ]
  },
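  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To connect the objective to numbers, the sketch below evaluates the two\n",
    "losses in NumPy from discriminator logits. It assumes the usual sigmoid\n",
    "cross-entropy formulation with a non-saturating generator loss; the function\n",
    "names are hypothetical and this is not a transcription of `GANInference`.\n",
    "For an uninformative discriminator that outputs a logit of 0 everywhere, the\n",
    "discriminator loss is $2\\log 2 \\approx 1.386$ and the generator loss is\n",
    "$\\log 2 \\approx 0.693$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def softplus(z):\n",
    "  \"\"\"Numerically stable log(1 + exp(z)).\"\"\"\n",
    "  return np.maximum(z, 0.0) + np.log1p(np.exp(-np.abs(z)))\n",
    "\n",
    "\n",
    "def sketch_gan_losses(logits_real, logits_fake):\n",
    "  \"\"\"Sigmoid cross-entropy losses computed from discriminator logits.\"\"\"\n",
    "  # -log D(x) for real data plus -log(1 - D(G(eps))) for generated samples.\n",
    "  loss_d = np.mean(softplus(-logits_real)) + np.mean(softplus(logits_fake))\n",
    "  # Non-saturating generator loss: -log D(G(eps)).\n",
    "  loss_g = np.mean(softplus(-logits_fake))\n",
    "  return loss_g, loss_d\n",
    "\n",
    "\n",
    "# An uninformative discriminator: logit 0, i.e. probability 0.5, everywhere.\n",
    "loss_g0, loss_d0 = sketch_gan_losses(np.zeros(8), np.zeros(8))\n",
    "print('Gen Loss = {:.3f}: Disc Loss = {:.3f}'.format(loss_g0, loss_d0))"
   ]
  },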
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll use ADAM as optimizers for both the generator and discriminator.\n",
    "We'll run the algorithm for 15,000 iterations and print progress every\n",
    "1,000 iterations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "optimizer = tf.train.AdamOptimizer()\n",
    "optimizer_d = tf.train.AdamOptimizer()\n",
    "\n",
    "inference = ed.GANInference(\n",
    "    data={x: x_ph}, discriminator=discriminative_network)\n",
    "inference.initialize(\n",
    "    optimizer=optimizer, optimizer_d=optimizer_d,\n",
    "    n_iter=15000, n_print=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now form the main loop which trains the GAN. At each iteration, it\n",
    "takes a minibatch and updates the parameters according to the\n",
    "algorithm. At every 1000 iterations, it will print progress and also\n",
    "saves a figure of generated samples from the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration     1 [  0%]: Gen Loss = 0.667: Disc Loss = 1.303\n",
      "Iteration  1000 [  6%]: Gen Loss = 14.898: Disc Loss = 0.026\n",
      "Iteration  2000 [ 13%]: Gen Loss = 4.941: Disc Loss = 0.034\n",
      "Iteration  3000 [ 20%]: Gen Loss = 5.195: Disc Loss = 0.054\n",
      "Iteration  4000 [ 26%]: Gen Loss = 4.918: Disc Loss = 0.107\n",
      "Iteration  5000 [ 33%]: Gen Loss = 4.584: Disc Loss = 0.165\n",
      "Iteration  6000 [ 40%]: Gen Loss = 4.029: Disc Loss = 0.221\n",
      "Iteration  7000 [ 46%]: Gen Loss = 3.286: Disc Loss = 0.594\n",
      "Iteration  8000 [ 53%]: Gen Loss = 3.374: Disc Loss = 0.312\n",
      "Iteration  9000 [ 60%]: Gen Loss = 2.803: Disc Loss = 0.611\n",
      "Iteration 10000 [ 66%]: Gen Loss = 2.362: Disc Loss = 0.745\n",
      "Iteration 11000 [ 73%]: Gen Loss = 2.974: Disc Loss = 0.526\n",
      "Iteration 12000 [ 80%]: Gen Loss = 2.787: Disc Loss = 0.514\n",
      "Iteration 13000 [ 86%]: Gen Loss = 2.413: Disc Loss = 0.852\n",
      "Iteration 14000 [ 93%]: Gen Loss = 1.990: Disc Loss = 0.696\n",
      "Iteration 15000 [100%]: Gen Loss = 1.929: Disc Loss = 0.781\n"
     ]
    }
   ],
   "source": [
    "sess = ed.get_session()\n",
    "tf.global_variables_initializer().run()\n",
    "\n",
    "idx = np.random.randint(M, size=16)\n",
    "i = 0\n",
    "for t in range(inference.n_iter):\n",
    "  if t % inference.n_print == 0:\n",
    "    samples = sess.run(x)\n",
    "    samples = samples[idx, ]\n",
    "\n",
    "    fig = plot(samples)\n",
    "    plt.savefig(os.path.join(out_dir, '{}.png').format(\n",
    "        str(i).zfill(3)), bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "    i += 1\n",
    "\n",
    "  x_batch = next(x_train_generator)\n",
    "  info_dict = inference.update(feed_dict={x_ph: x_batch})\n",
    "  inference.print_progress(info_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Examining convergence of the GAN objective can be meaningless in\n",
    "practice. The algorithm is usually run until some other criterion is\n",
    "satisfied, such as if the samples look visually okay, or if the GAN\n",
    "can capture meaningful parts of the data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Criticism\n",
    "\n",
    "Evaluation of GANs remains an open problem---both in criticizing their\n",
    "fit to data and in assessing convergence.\n",
    "Recent advances have considered alternative objectives and\n",
    "heuristics to stabilize training (see also Soumith Chintala's\n",
    "[GAN hacks repo](https://github.com/soumith/ganhacks)).\n",
    "\n",
    "As one approach to criticize the model, we simply look at generated\n",
    "images during training. Below we show generations after 14,000\n",
    "iterations (that is, 14,000 gradient updates of both the generator and\n",
    "the discriminator).\n",
    "\n",
    "![GAN Fig 1](https://raw.githubusercontent.com/blei-lab/edward/master/docs/images/gan-fig1.png)\n",
    "                                  \n",
    "The images are meaningful albeit a little blurry. Suggestions for\n",
    "further improvements would be to tune the hyperparameters in the\n",
    "optimization, to improve the capacity of the discriminative and\n",
    "generative networks, and to leverage more prior information (such as\n",
    "convolutional architectures)."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
